import torch
from model.model import *

if __name__ == '__main__':

    model_name = 'ResGPT2.pt'
    device = torch.device("cuda")  # 如果可用，使用GPU进行加速
    model = GPT().to(device)  # 将模型移动到GPU上

    try:
        model.load_state_dict(torch.load(model_name))
    except FileNotFoundError:
        print(f"无法找到GPT-2.0模型文件{model_name}，请检查文件是否存在。")
        input('按回车键退出...')
        exit()

    model.eval()

    sentence = ''
    while True:
        temp_sentence = input("User:")
        sentence += (temp_sentence + '\t')
        if len(sentence) > 200:
            t_index = sentence.find('\t')
            sentence = sentence[t_index + 1:]
        print("ResGPT2:", model.answer(sentence), "\n")
        